#!/usr/bin/env python
# coding: utf-8

# In[14]:


from datetime import datetime, timedelta
from netCDF4 import Dataset
import matplotlib
import matplotlib.pyplot as plt
import os
import numpy as np
import cmocean.cm as cmo
from matplotlib import ticker 
import glob 
import matplotlib as mpl
from wrf import getvar, interplevel, get_cartopy, vertcross
from matplotlib.colors import LogNorm
import sys
from string import ascii_lowercase

# Global matplotlib styling for every figure produced by this script:
# figure size/titles, subplot spacing, font sizes, line widths, legend look,
# and tight-bbox, high-resolution saving.
mpl.rcParams.update({
    # figure
    'figure.figsize': [10, 10],
    'figure.titlesize': 10,
    'figure.titleweight': 'bold',
    'figure.subplot.wspace': 0.05,
    'figure.subplot.hspace': 0.05,
    # axes, ticks and lines
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'axes.labelsize': 11,
    'axes.titlesize': 10,
    'lines.linewidth': 1.8,
    'grid.linewidth': .25,
    # legend
    'legend.fontsize': 11,
    'legend.framealpha': .75,
    'legend.loc': 'best',
    # saving
    'savefig.bbox': 'tight',
    'savefig.dpi': 600,
})

def shiftedColorMap(cmap, start=0, midpoint=0.5, stop=1.0, name='shiftedcmap'):
    '''
    Return a copy of `cmap` whose center has been moved to `midpoint`.

    Useful for data with a negative min and positive max when you want
    the middle of the colormap's dynamic range to sit at zero. The new
    colormap is also registered with matplotlib under `name`; calling
    the function again with the same name replaces that registration.

    Input
    -----
      cmap : The matplotlib colormap to be altered (called as
          ``cmap(value)`` and expected to return an RGBA tuple).
      start : Offset from lowest point in the colormap's range.
          Defaults to 0.0 (no lower offset). Should be between
          0.0 and `midpoint`.
      midpoint : The new center of the colormap. Defaults to
          0.5 (no shift). Must be between 0.0 and 1.0. In
          general, this should be  1 - vmax / (vmax + abs(vmin)).
          For example if your data range from -15.0 to +5.0 and
          you want the center of the colormap at 0.0, `midpoint`
          should be set to  1 - 5/(5 + 15), i.e. 0.75.
      stop : Offset from highest point in the colormap's range.
          Defaults to 1.0 (no upper offset). Should be between
          `midpoint` and 1.0.
      name : Name given to (and registered for) the new colormap.

    Returns
    -------
      matplotlib.colors.LinearSegmentedColormap

    Raises
    ------
      ValueError : if `midpoint` lies outside [0, 1]. (Such a value would
          otherwise only fail later, when the colormap is first used,
          because the segment x-positions would not be increasing.)
    '''
    if not 0.0 <= midpoint <= 1.0:
        raise ValueError('midpoint must be between 0.0 and 1.0, got %r' % (midpoint,))

    cdict = {
        'red': [],
        'green': [],
        'blue': [],
        'alpha': []
    }

    # 257 evenly spaced samples of the source colormap between start and stop.
    reg_index = np.linspace(start, stop, 257)

    # Where each sample is placed in the new colormap: the first 128 samples
    # are squeezed into [0, midpoint) and the remaining 129 spread over
    # [midpoint, 1], so the middle sample (index 128) lands exactly on midpoint.
    shift_index = np.hstack([
        np.linspace(0.0, midpoint, 128, endpoint=False),
        np.linspace(midpoint, 1.0, 129, endpoint=True)
    ])

    for ri, si in zip(reg_index, shift_index):
        r, g, b, a = cmap(ri)

        # No discontinuities: the colour just below and above si is the same.
        cdict['red'].append((si, r, r))
        cdict['green'].append((si, g, g))
        cdict['blue'].append((si, b, b))
        cdict['alpha'].append((si, a, a))

    newcmap = matplotlib.colors.LinearSegmentedColormap(name, cdict)

    # plt.register_cmap was deprecated in matplotlib 3.6 and removed in 3.9,
    # and from 3.6 it refuses to re-register an existing name. Use the
    # colormap registry with force=True when available so repeated calls
    # with the same name work; fall back to the old API on older releases.
    registry = getattr(matplotlib, 'colormaps', None)
    if registry is not None and hasattr(registry, 'register'):
        registry.register(newcmap, name=name, force=True)
    else:
        plt.register_cmap(cmap=newcmap)

    return newcmap

# Output of the "..._no_feedback" run; its 'met/' and 'fire/' subdirectories
# hold separate sets of wrfout files with the same names.
_run_root = '/glade/scratch/tjuliano/fire/marshall/joinwrfh/1min/big2_maria_run9_add_spotting_no_feedback/'
rdir = _run_root + 'met/'
#rdir2 = '/glade/scratch/tjuliano/fire/marshall/joinwrfh/1min/big2_maria_run9_add_spotting/turb/'
rdir3 = _run_root + 'fire/'
#rdir = '/glade/scratch/tjuliano/fire/sagehen/WRF-Fire-LEAPHI/WRF/test/marshall/2021123012/run/netcdf_2/'
# NOTE(review): presumably a grid spacing of 111.111 m expressed in km; it is
# not referenced elsewhere in this script.
dx = 111.111/1000


# ## For making the cross-section (xs) plots without the wind vectors

# In[ ]:



# Panel letters a, b, c, ... (kept for labelling panels; not used below).
lowerc = list(iter(ascii_lowercase))
from matplotlib.colors import LogNorm
#for tidx in np.arange(240, 271, 30):#[240, 360, 480]:
# NOTE(review): sdir is not used below -- the figure is saved to the working directory.
sdir = 'xs/for_paper/xsect_u/'
# (k, j, i) index of the maximum of u at each plotted time.
loc = []
#for tidx in np.arange(480, 541, 10):#[240, 360, 480]:
#offset = [-20,-15,-10,-5,0,5,10,15,20]
# Row offsets (grid points) of the cross section relative to row 262 + 15.
offset = [-10]
# Times to plot (December 2021): one panel per (day, hour, minute) triple.
dd = [30,30,30,30,31]
hh = [19,19,22,23,2]
mm = [0,45,5,0,30]
# One reference longitude per panel, drawn as a dotted magenta line.
# NOTE(review): presumably an observed fire-front position -- confirm.
xlon_fr = [-105.05676,-105.01111,-105.16763,-105.1911,-105.17415]

# Diverging colormap shifted so u = 0 falls at 5/45 of the plotted range
# (contour levels run from -5 to 40 m/s). Built once rather than per panel:
# every call registers the name 'shifted', and depending on the matplotlib
# version registering the same name again warns or raises ValueError.
orig_cmap = cmo.balance
shifted_cmap = shiftedColorMap(orig_cmap, midpoint=5./45., name='shifted')

for tidx in np.arange(len(dd)):
    time = datetime(2021, 12, dd[tidx], hh[tidx], mm[tidx], 0)
#    time = time + timedelta(minutes=int(tidx))
    filename = 'wrfout_d02_%s' % time.strftime('%Y-%m-%d_%H:%M:%S')
    print (filename)

    # Atmospheric fields from the 'met' output directory.
    with Dataset(rdir + filename) as nc:
        xlat = getvar(nc, 'lat').values
        xlon = getvar(nc, 'lon')
        u        = getvar(nc, 'ua').values
        # NOTE(review): 'T' is presumably WRF's perturbation potential
        # temperature; adding 300 K gives potential temperature.
        th       = getvar(nc, 'T').values + 300.
        ph       = getvar(nc, 'PH').values
        phb      = getvar(nc,'PHB').values
        # Height (m) of each level used below: average adjacent levels of
        # PH + PHB along axis 0 (vertical), then divide by g = 9.81.
        zht      = ((ph[1:]+ph[0:-1])/2. +(phb[1:]+phb[0:-1])/2.)/9.81
        hgt      = getvar(nc,'HGT').values

#    with Dataset(rdir2 + filename) as nc:
#        tke      = getvar(nc, 'TKE').values

    # Fire-model fields from the 'fire' output directory.
    with Dataset(rdir3 + filename) as nc:
        farea   = getvar(nc, 'FIRE_AREA').values
        fxlon   = getvar(nc, 'FXLONG').values

    loc.append(np.unravel_index(np.argmax(u), np.shape(u)))

    # Easternmost (largest) longitude of any cell with nonzero FIRE_AREA;
    # a fixed fallback longitude is used while nothing has burned yet.
    # Masking avoids writing NaNs into fxlon through a ravel() view.
    if np.nanmax(farea) > 0.0:
        xlon_fire = np.nanmax(fxlon[farea != 0.0])
    else:
        xlon_fire = -105.230189

    for kk in np.arange(len(offset)):
        # Cross section over the fire along row x of the horizontal grid.
        x = 262 + 15 + offset[kk]

        #print (np.mean(xlat[x,:]), np.mean(xlat[x+10,:]), np.mean(xlat[x-10,:]))

        xlonv = np.meshgrid(xlon[x, :], range(u.shape[0]))[0]
        zv    = np.squeeze(zht[:,x, :])
        uv    = np.squeeze(u[:,x,:])
#        tkev  = np.squeeze(tke[:,x,:])
        thv   = np.squeeze(th[:,x,:])
        hgtv  = hgt[x, :]

        ### COMPUTE PBLH
        # Per column: height above the surface (zv - HGT) of the level where
        # d(theta)/dz is largest within the lowest 36 levels. Only used by the
        # commented-out overlay below.
        dthdz = (thv[1:36] - thv[0:35]) / (zv[1:36] - zv[0:35])
        pblh = zv[np.argmax(dthdz, axis=0), np.arange(thv.shape[1])] - hgtv

        # Panels fill a 3 x 2 grid in time order; the figure is created
        # together with the first panel.
        if tidx == 0:
            fig = plt.figure(figsize=(12, 9))
        axs = plt.subplot(3, 2, tidx+1)

        # Filled contours: u. Green lines: theta every 2 K, labelled on every
        # other level.
        c0 = axs.contourf(xlonv, zv, uv,  cmap=shifted_cmap,levels=np.arange(-5, 42.5, 2.5), extend='both')
        c00 = axs.contour(xlonv, zv, thv, np.arange(0,502,2), linewidths=1., colors='limegreen')
        # Dashed: easternmost burning longitude (xlon_fire); dotted: xlon_fr.
        axs.plot([xlon_fire,xlon_fire],[1500,6000],lw=2.,ls='--',c='magenta')
        axs.plot([xlon_fr[tidx],xlon_fr[tidx]],[1500,6000],lw=2.,ls=':',c='magenta')
#        axs.plot(xlonv[0,:], pblh+hgtv, lw=2., c='magenta')
        axs.clabel(c00,c00.levels[::2],fmt='%d')

#        axs.set_title(time.strftime('%Y-%m-%d %H:%M:%SZ') + ', ' + str(round(np.mean(xlat[x,:]),3)))
        axs.set_title(time.strftime('%Y-%m-%d %H:%M:%SZ'))
        # Grey fill from 0 up to the lowest model level (terrain).
        axs.fill_between(xlon[x, :], 0, zv[0], facecolor="gray")

        axs.set_ylim(1500, 6000)
        # Right-hand column: keep the ticks but blank their labels.
        if tidx in (1, 3):
            axs.set_yticks([2000,3000,4000,5000,6000])
            axs.set_yticklabels(['','','','',''])
        # Lowest panel of each column gets the x label.
        if tidx in (3, 4):
            axs.set_xlabel('Longitude')
        # Left-hand column gets the y label.
        if tidx in (0, 2, 4):
            axs.set_ylabel('Height ASL (m)')

plt.tight_layout()

# One shared horizontal colorbar in the empty bottom-right slot of the grid;
# every panel uses the same levels and colormap, so the last contourf (c0)
# stands for all of them.
cbar_ax = fig.add_axes([0.565, 0.2, 0.4, 0.05])
cbar = fig.colorbar(c0, cax=cbar_ax, orientation='horizontal')
cbar.ax.tick_params(labelsize=14)
cbar.set_ticks([-5,0,5,10,15,20,25,30,35,40])
cbar.ax.set_xlabel(r'$U$ (m s$^{-1}$)',fontsize=20,labelpad=20)

# The file name uses the last offset index (kk) and the last panel's time.
# '%Y' instead of the glibc-only '%4Y', which is not portable.
plt.savefig('no_feedback_xs_lat' + str(kk)  + '_u_panel_%s' % time.strftime('%Y-%m-%d_%H:%M:%S'), dpi=600, bbox_inches='tight')
plt.close()

# ## For making the GIF file (animation of the cross-section PNGs)

# In[99]:


# Everything below is switched off by this sys.exit(); delete the call to
# build the animated GIF from the cross-section PNGs.
sys.exit()

# Imported only here, so the plotting section does not require Pillow.
from PIL import Image

# filepaths
fp_in = "./xs/*.png"
fp_out = "./xs/xs_xx.gif"

# https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#gif
all_files = sorted(glob.glob(fp_in))
frames = []
for path in all_files:
    # Copy the pixels into memory and close the file right away, instead of
    # leaving one open file handle per PNG.
    with Image.open(path) as frame:
        frames.append(frame.copy())

# Each frame is shown for 300 ms; loop=0 repeats the animation indefinitely.
if frames:
    frames[0].save(fp=fp_out, format='GIF', append_images=frames[1:],
                   save_all=True, duration=300, loop=0)
else:
    print('No PNG files match %s; %s not written.' % (fp_in, fp_out))


# In[ ]:




